import pandas as pd
import matplotlib.pyplot as plt
import numpy as np

# Global font sizes shared by every figure produced in this module
plt.rcParams.update({
    'font.size': 14,  # Base font size
    'axes.titlesize': 16,
    'axes.labelsize': 14,
    'xtick.labelsize': 12,
    'ytick.labelsize': 12,
    'legend.fontsize': 12,
})

def plot_yearly_events(df):
    """Plot the number of events per year as stacked train/test bars.

    Args:
        df: DataFrame with a datetime-like ``date`` column (accessed through
            ``.dt.year``) and a boolean ``is_train`` column, where ``True``
            marks training rows.

    Side effects:
        Writes ``yearly_events.png`` to the current working directory and
        closes the figure afterwards.
    """
    plt.figure(figsize=(15, 6))

    # Yearly counts by split.  ``unstack`` sorts the boolean column labels as
    # [False, True], so the columns are selected by value rather than renamed
    # by position (a positional rename would swap Train and Test).  The
    # reindex also supplies an all-zero column when one split has no rows.
    yearly_counts = (
        df.groupby([df['date'].dt.year, 'is_train'])
        .size()
        .unstack(fill_value=0)
        .reindex(columns=[True, False], fill_value=0)
        .rename(columns={True: 'Train', False: 'Test'})
    )

    # Plot stacked bars: Test is drawn on top of Train
    plt.bar(yearly_counts.index, yearly_counts['Train'],
            label='Train', color='#2ecc71', alpha=0.7)
    plt.bar(yearly_counts.index, yearly_counts['Test'],
            bottom=yearly_counts['Train'],
            label='Test', color='#e74c3c', alpha=0.7)

    # Add split line
    plt.axvline(x=2017, color='black', linestyle='--',
                alpha=0.5, label='Train/Test Split')

    plt.title('Number of Events per Year (1970-2020)', pad=20, fontsize=18)
    plt.xlabel('Year', fontsize=14)
    plt.ylabel('Number of Events', fontsize=14)
    plt.legend(fontsize=12)
    plt.grid(True, alpha=0.3)
    plt.xticks(rotation=45)

    plt.tight_layout()
    plt.savefig('yearly_events.png', dpi=300, bbox_inches='tight')
    plt.close()

def plot_monthly_totals(df):
    """Plot total events per calendar month, one line per split.

    Args:
        df: DataFrame with a datetime-like ``date`` column (accessed through
            ``.dt.month``) and a boolean ``is_train`` column, where ``True``
            marks training rows.

    Side effects:
        Writes ``monthly_totals.png`` to the current working directory and
        closes the figure afterwards.
    """
    plt.figure(figsize=(15, 6))

    months = list(range(1, 13))

    # Monthly totals.  Reindex onto the full (split x month) grid so that a
    # month with no events, or a split with no rows, is plotted as zero
    # instead of raising on a column-length mismatch or a missing row label.
    monthly_totals = (
        df.groupby(['is_train', df['date'].dt.month])
        .size()
        .unstack(fill_value=0)
        .reindex(index=[True, False], columns=months, fill_value=0)
    )

    # Plot lines
    plt.plot(months, monthly_totals.loc[True], 'o-',
             label='Train', color='#2ecc71', alpha=0.7, markersize=8)
    plt.plot(months, monthly_totals.loc[False], 'o-',
             label='Test', color='#e74c3c', alpha=0.7, markersize=8)

    plt.title('Total Events by Month', pad=20, fontsize=18)
    plt.xlabel('Month', fontsize=14)
    plt.ylabel('Total Number of Events', fontsize=14)
    plt.grid(True, alpha=0.3)
    plt.legend(fontsize=12)

    plt.xticks(months,
               ['Jan', 'Feb', 'Mar', 'Apr', 'May', 'Jun',
                'Jul', 'Aug', 'Sep', 'Oct', 'Nov', 'Dec'])

    plt.tight_layout()
    plt.savefig('monthly_totals.png', dpi=300, bbox_inches='tight')
    plt.close()

def plot_attack_type_evolution(df):
    """Plot the yearly share of each of the five most frequent attack types.

    The percentages are computed among events of the top five attack types
    only: for every year the plotted lines sum to 100, and events of other
    attack types are excluded from the denominator.

    Args:
        df: DataFrame with a datetime-like ``date`` column (accessed through
            ``.dt.year``) and an ``attacktype1_txt`` column.

    Side effects:
        Writes ``attack_evolution.png`` to the current working directory and
        closes the figure afterwards.
    """
    plt.figure(figsize=(15, 8))

    # Get top 5 attack types
    top_attacks = df['attacktype1_txt'].value_counts().nlargest(5).index

    # Yearly counts for each top attack type.  Year/type pairs with no events
    # are filled with 0 so the line shows 0% instead of a gap, and the reindex
    # guarantees a column for every top attack type.
    top_events = df[df['attacktype1_txt'].isin(top_attacks)]
    yearly_props = (
        top_events.groupby([top_events['date'].dt.year, 'attacktype1_txt'])
        .size()
        .unstack(fill_value=0)
        .reindex(columns=top_attacks, fill_value=0)
    )

    # Convert to percentages (row-wise share within each year)
    yearly_props = yearly_props.div(yearly_props.sum(axis=1), axis=0) * 100

    # Plot lines for each attack type
    colors = ['#2ecc71', '#e74c3c', '#3498db', '#f1c40f', '#9b59b6']
    for attack, color in zip(top_attacks, colors):
        plt.plot(yearly_props.index, yearly_props[attack],
                 label=attack, color=color, alpha=0.7, linewidth=2)

    plt.axvline(x=2017, color='black', linestyle='--',
                alpha=0.5, label='Train/Test Split')

    plt.title('Evolution of Attack Types (1970-2020)', pad=20, fontsize=18)
    plt.xlabel('Year', fontsize=14)
    plt.ylabel('Percentage of Events', fontsize=14)
    # Legend sits outside the axes, anchored to the right of the plot area
    plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left', fontsize=12)
    plt.grid(True, alpha=0.3)
    plt.xticks(rotation=45)

    plt.tight_layout()
    plt.savefig('attack_evolution.png', dpi=300, bbox_inches='tight')
    plt.close()

def plot_multi_label_trends(df):
    """Chart, per year, the share of events with a 2nd / 3rd attack type.

    For each year the percentage of rows whose ``attacktype2_txt`` (and,
    separately, ``attacktype3_txt``) is non-null is drawn as a line.  The
    figure is saved as ``multilabel_trends.png`` and then closed.
    """
    def _pct_present(values):
        # Percentage of non-null entries within one yearly group
        return (values.notna().sum() / len(values)) * 100

    plt.figure(figsize=(15, 6))

    # One row per year, one column per additional attack-type field
    per_year = df.groupby(df['date'].dt.year)
    yearly_stats = per_year.agg({
        'attacktype2_txt': _pct_present,
        'attacktype3_txt': _pct_present,
    })

    # (column, legend label, line colour) for each plotted series
    series_specs = (
        ('attacktype2_txt', 'Secondary Attack', '#2ecc71'),
        ('attacktype3_txt', 'Tertiary Attack', '#e74c3c'),
    )
    for column, legend_label, line_color in series_specs:
        plt.plot(yearly_stats.index, yearly_stats[column],
                 label=legend_label, color=line_color, alpha=0.7, linewidth=2)

    # Dashed vertical marker at the train/test boundary year
    plt.axvline(x=2017, color='black', linestyle='--',
                alpha=0.5, label='Train/Test Split')

    plt.title('Evolution of Multi-Label Events (1970-2020)', pad=20, fontsize=18)
    plt.xlabel('Year', fontsize=14)
    plt.ylabel('Percentage of Events', fontsize=14)
    plt.legend(fontsize=12)
    plt.grid(True, alpha=0.3)
    plt.xticks(rotation=45)

    plt.tight_layout()
    plt.savefig('multilabel_trends.png', dpi=300, bbox_inches='tight')
    plt.close()

def main():
    """Load the dataset, derive date/split columns and render every chart.

    Reads ``globalterrorismdb_0522dist.xlsx`` from the current directory,
    builds a ``date`` column from ``iyear``/``imonth``/``iday`` and an
    ``is_train`` flag for rows dated before 2017-01-01, then calls each
    plotting function and prints the names of the saved files.
    """
    print("Loading data...")
    df = pd.read_excel('globalterrorismdb_0522dist.xlsx', engine='openpyxl')

    print("Processing dates...")
    # Missing month/day default to 1 before the integer cast
    for part in ('imonth', 'iday'):
        df[part] = df[part].fillna(1).astype(int)

    # Clip into calendar range so out-of-range values (e.g. 0) become valid
    date_parts = {
        'year': df['iyear'],
        'month': df['imonth'].clip(1, 12),
        'day': df['iday'].clip(1, 31),
    }
    df['date'] = pd.to_datetime(date_parts)

    split_start = pd.Timestamp('2017-01-01')
    df['is_train'] = df['date'] < split_start

    print("Creating visualizations...")
    plotters = (
        plot_yearly_events,
        plot_monthly_totals,
        plot_attack_type_evolution,
        plot_multi_label_trends,
    )
    for make_plot in plotters:
        make_plot(df)

    print("Visualizations saved as:")
    saved_files = (
        'yearly_events.png',
        'monthly_totals.png',
        'attack_evolution.png',
        'multilabel_trends.png',
    )
    for filename in saved_files:
        print("- " + filename)

# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()